import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler, OneHotEncoder
import gc
import torch
from torch_geometric.data import Data
import time
import torch_geometric.transforms as T
from torch_geometric.utils import dense_to_sparse

class DataSpace:
    """Generate training / validation / testing data for a single graph.

    The graph is built once as a ``torch_geometric.data.Data`` object,
    ``n_splits`` stratified train/validation splits of the labelled nodes are
    precomputed, and ``get_data`` installs the split for a given round as
    boolean node masks.
    """

    def __init__(self, info, data):
        """
        Generating training / validation / testing data.
        Parameters:
        ----------
        info: dict
            The EDA information generated by AutoEDA. Keys read here:
            'label_weights' and 'normalize_features'.
        data: dict
            The original data passed by the ingestion program. Keys read here:
            'fea_table', 'edge_file', 'train_label', 'train_indices' and
            'test_indices'.
        ----------
        """
        self.info = info

        # Stratification labels, in the row order of data['train_label'].
        self.y = data['train_label']['label'].to_numpy()
        self.pyg_data, self.all_train_idxs, self.test_idxs = self.generate_pyg_data(data)
        self.splits = {}
        self.n_splits = 5
        self.split_train_valid(ratio=0.1)
        self.update = False

    def split_train_valid(self, ratio=0.1):
        """Precompute ``self.n_splits`` stratified (train, valid) node-index splits.

        ``ratio`` is the fraction of ``self.all_train_idxs`` held out for
        validation in each split. Results are stored in ``self.splits`` keyed
        by split number ``0 .. n_splits - 1``; ``random_state=0`` makes the
        splits reproducible.
        """
        # NOTE(review): position i of self.y is assumed to be the label of node
        # self.all_train_idxs[i], i.e. data['train_label'] rows are in the same
        # order as data['train_indices'] -- confirm with the ingestion program.
        sss = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=ratio, random_state=0)
        for i, (train, val) in enumerate(sss.split(self.all_train_idxs, self.y)):
            self.splits[i] = (self.all_train_idxs[train], self.all_train_idxs[val])

    def get_data(self, round_num):
        """Return the graph with train/valid masks installed for ``round_num``.

        Rounds are 1-based and cycle through the precomputed splits. The masks
        are created on the same device as the node features so they can be
        combined with the graph's other tensors (e.g. ``test_mask``).
        """
        train_idxs, val_idxs = self.splits[(round_num - 1) % self.n_splits]
        print(f'Round {round_num}')

        num_nodes = self.pyg_data.num_nodes
        device = self.pyg_data.x.device
        self.pyg_data.train_mask = self._index_mask(train_idxs, num_nodes, device)
        self.pyg_data.valid_mask = self._index_mask(val_idxs, num_nodes, device)
        return self.pyg_data

    @staticmethod
    def _index_mask(idxs, num_nodes, device=None):
        """Return a bool tensor of length ``num_nodes`` that is True at ``idxs``."""
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        mask[torch.as_tensor(idxs, dtype=torch.long, device=device)] = True
        return mask

    @staticmethod
    def _build_edges(edge_df):
        """Return ``(edge_index, edge_weight)`` sorted by source node.

        ``edge_index`` has shape (2, E) and ``edge_weight`` shape (E,). A single
        stable permutation is applied to both, so each weight stays attached
        to its own edge. An empty table yields shapes (2, 0) and (0,).
        """
        pairs = edge_df[['src_idx', 'dst_idx']].to_numpy(dtype=np.int64)
        weights = edge_df['edge_weight'].to_numpy(dtype=np.float32)
        # Stable sort keeps the original relative order of edges sharing a source.
        order = np.argsort(pairs[:, 0], kind='stable')
        edge_index = torch.as_tensor(pairs[order], dtype=torch.long).t().contiguous()
        edge_weight = torch.as_tensor(weights[order], dtype=torch.float32)
        return edge_index, edge_weight

    @staticmethod
    def _normalize_features(x, mode):
        """Normalize ``x`` by row sums (``'row'``) or column sums (``'col'``).

        Sums are clamped to at least 1 to avoid division by zero; any other
        ``mode`` returns ``x`` unchanged.
        """
        if mode == 'row':
            print('Feature Normalized By Row')
            return x / x.sum(1, keepdim=True).clamp(min=1)
        if mode == 'col':
            print('Feature Normalized By Column')
            return x / x.sum(0, keepdim=True).clamp(min=1)
        return x

    def generate_pyg_data(self, data, device=None):
        """Build the PyG graph from the raw ingestion data.

        Parameters:
        ----------
        data: dict
            Raw ingestion data (see ``__init__`` for the keys read).
        device: str or torch.device, optional
            Where to place the graph. Defaults to 'cuda' when available and
            falls back to 'cpu' otherwise.
        ----------
        Returns ``(graph, all_train_idxs, test_idxs)``; the index arrays are
        numpy int arrays built from data['train_indices'] / data['test_indices'].
        """
        x = data['fea_table'].drop('node_index', axis=1).to_numpy()
        x = torch.tensor(x, dtype=torch.float)
        num_nodes = x.size(0)

        edge_index, edge_weight = self._build_edges(data['edge_file'])

        # Nodes without a row in train_label keep label 0.
        label_df = data['train_label']
        y = torch.zeros(num_nodes, dtype=torch.long)
        node_ids = torch.as_tensor(label_df['node_index'].to_numpy(dtype=np.int64))
        y[node_ids] = torch.as_tensor(label_df['label'].to_numpy(dtype=np.int64))

        all_train_idxs = np.array(data['train_indices'], dtype=int)
        test_idxs = np.array(data['test_indices'], dtype=int)

        graph = Data(x=x, edge_index=edge_index, y=y, edge_weight=edge_weight)
        graph.num_nodes = num_nodes
        graph.test_idxs = test_idxs
        graph.test_mask = self._index_mask(test_idxs, num_nodes)
        graph.label_weights = self.info['label_weights']
        graph.x = self._normalize_features(graph.x, self.info['normalize_features'])

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return graph.to(device), all_train_idxs, test_idxs
